library(tidyverse)
library(tidyboot)
library(ggplot2)
library(ggthemes)
library(knitr)
library(coda)
library(viridis)
library(here)
library(patchwork)
theme_set(theme_few())
# Estimate the mode of a numeric sample: fit a kernel density with the
# default density() settings and return the x location of its highest point.
estimate_mode <- function(s) {
  kde <- density(s)
  peak_index <- which.max(kde$y)
  kde$x[peak_index]
}
# Upper bound of the highest posterior density interval of the samples `s`,
# as computed by coda::HPDinterval() with its default `prob`.
hpd_upper <- function(s) {
  interval <- HPDinterval(mcmc(s))
  interval["var1", "upper"]
}
# Lower bound of the highest posterior density interval of the samples `s`,
# as computed by coda::HPDinterval() with its default `prob`.
hpd_lower <- function(s) {
  interval <- HPDinterval(mcmc(s))
  interval["var1", "lower"]
}

# Summary function handed to tidyboot(): count the rows in each group of `x`,
# then express each count as a share of the total count (`stat`), computed
# within whatever grouping is left after summarize().
count_summary_fn <- function(x) {
  counts <- summarize(x, n = n())
  mutate(counts, stat = n / sum(n))
}

# Named list of statistics applied to the bootstrapped `stat` column; the
# names become the output column names via `.names = "{.fn}"` in across().
mean_ci_funs <- list(
  ci_lower = ci_lower,
  mean = mean,
  ci_upper = ci_upper
)

Load human data

State

# Read the cleaned true-state judgments. Path components are passed to here()
# separately so that no leading "/" is appended to the project root.
h_state <- read_csv(here("data", "clean_data_true_state.csv"))
## Parsed with column specification:
## cols(
##   id = col_character(),
##   condition_level = col_character(),
##   response = col_double(),
##   utt = col_character(),
##   exp = col_character(),
##   condition_name = col_character()
## )
# Bootstrapped proportion of each state response within every
# manipulation x level x utterance x emotion cell. Columns are renamed to the
# names used in the model results so both sources can be combined later.
h_state_summary <- h_state %>%
  rename(
    emo = exp,
    state = response,
    manipulation = condition_name,
    manipulation_level = condition_level
  ) %>%
  mutate(
    # shorten the goal labels; every other level is kept unchanged
    manipulation_level = case_when(
      manipulation_level == "inf_goal" ~ "inf",
      manipulation_level == "soc_goal" ~ "soc",
      TRUE ~ manipulation_level
    )
  ) %>%
  group_by(manipulation, manipulation_level, utt, emo, state) %>%
  tidyboot(
    summary_function = count_summary_fn,
    statistics_functions = function(boot) {
      summarise(boot, across(stat, mean_ci_funs, .names = "{.fn}"))
    }
  ) %>%
  mutate(condition = "human")
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo'. You can override using the `.groups` argument.
## `summarise()` has grouped output by '.id', 'manipulation', 'manipulation_level', 'utt', 'emo'. You can override using the `.groups` argument.
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo'. You can override using the `.groups` argument.

Goals

# Read the cleaned goal ratings. Path components are passed to here()
# separately so that no leading "/" is appended to the project root.
h_goal <- read_csv(here("data", "clean_data_goals.csv"))
## Parsed with column specification:
## cols(
##   id = col_character(),
##   condition_level = col_character(),
##   question = col_character(),
##   response = col_double(),
##   utt = col_character(),
##   exp = col_character(),
##   condition_name = col_character()
## )
# Bootstrapped proportion of each goal rating within every
# manipulation x level x utterance x emotion x question cell. The response
# column is renamed to `rating` at the end.
h_goal_summary <- h_goal %>%
  rename(
    emo = exp,
    manipulation = condition_name,
    manipulation_level = condition_level
  ) %>%
  mutate(
    # shorten the goal labels; every other level is kept unchanged
    manipulation_level = case_when(
      manipulation_level == "inf_goal" ~ "inf",
      manipulation_level == "soc_goal" ~ "soc",
      TRUE ~ manipulation_level
    )
  ) %>%
  group_by(manipulation, manipulation_level, utt, emo, question, response) %>%
  tidyboot(
    summary_function = count_summary_fn,
    statistics_functions = function(boot) {
      summarise(boot, across(stat, mean_ci_funs, .names = "{.fn}"))
    }
  ) %>%
  mutate(condition = "human") %>%
  rename(rating = response)
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo', 'question'. You can override using the `.groups` argument.
## `summarise()` has grouped output by '.id', 'manipulation', 'manipulation_level', 'utt', 'emo', 'question'. You can override using the `.groups` argument.
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo', 'question'. You can override using the `.groups` argument.
# Split the human goal summary by question. filter() is used instead of
# logical `[` indexing because an NA in `question` would otherwise produce an
# all-NA row in the result; filter() simply drops such rows.
h_inf_summary <- filter(h_goal_summary, question == "informational goal")
h_soc_summary <- filter(h_goal_summary, question == "social goal")

Load model results

# Load every results CSV whose file name matches "bda-M" from the results
# directory and stack them row-wise into `df.m`.
results_path <- "models/bda_results/"
model.files <- list.files(
  path = here(results_path),
  pattern = "bda-M"
)
# Stop here with a clear message: with no files, map_dfr() would return an
# empty table and the first filter() below would fail on a missing column.
if (length(model.files) == 0) {
  stop(
    "No result files matching 'bda-M' found in ", here(results_path),
    call. = FALSE
  )
}
df.m <- map_dfr(model.files, function(model.file) {
  read_csv(
    here(paste0(results_path, model.file)),
    col_types = cols(
      iter = col_double(),
      model = col_character(),
      chain = col_double(),
      manipulation = col_character(),
      manipulation_level = col_character(),
      parameter = col_character(),
      utt = col_character(),
      emo = col_character(),
      # read as text; numeric entries are converted with as.numeric() later
      value = col_character(),
      prob = col_double(),
      score = col_double()
    )
  )
})

Model Evaluation

Global parameters

# Histograms of the sampled global parameters (rows tagged "parameter" that
# carry no `emo` level), one panel per parameter name stored in `utt`.
ggplot(
  data = filter(df.m, parameter == "parameter", is.na(emo)),
  mapping = aes(x = prob)
) +
  geom_histogram(position = position_dodge()) +
  facet_grid(cols = vars(utt), scales = "free_x")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# ggsave(here("/models/figures/mb7_global_parameters.pdf"), width = 8, height = 5)

Summary

# Point estimate (density mode) and HPD bounds for each global parameter.
df_parameter_summary <- df.m %>%
  filter(parameter == "parameter", is.na(emo)) %>%
  group_by(utt) %>%
  summarize(
    MAP = estimate_mode(prob),
    cred_upper = hpd_upper(prob),
    cred_lower = hpd_lower(prob)
  )


# Print the global parameter summary as a table.
kable(df_parameter_summary)
utt MAP cred_upper cred_lower
goalExp 4.2829359 4.329513 1.0717456
goalScale 0.9991926 1.000000 1.0000000
speakerOptimality 0.9888722 1.029759 0.7771853

goalExp is bimodal! Plot the correlation between goalExp and speakerOptimality

# Scatter of goalExp against speakerOptimality: the global parameter rows are
# reshaped so each parameter name in `utt` becomes its own column of `prob`
# values. pivot_wider() replaces the superseded spread(), and the y aesthetic
# is named explicitly.
df.m %>%
  filter(parameter == "parameter", is.na(emo)) %>%
  pivot_wider(names_from = utt, values_from = prob) %>%
  ggplot(aes(x = goalExp, y = speakerOptimality)) +
  geom_point()

Prior parameters

# Histograms of the sampled prior parameters (rows tagged "parameter" that do
# carry an `emo` level): columns by `utt` x `emo`, rows by `value`.
ggplot(
  data = filter(df.m, parameter == "parameter", !is.na(emo)),
  mapping = aes(x = prob)
) +
  geom_histogram(position = position_dodge()) +
  facet_grid(
    cols = vars(utt, emo),
    rows = vars(value),
    scales = "free_x"
  )
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# ggsave(here("/models/figures/mb7_prior_parameters.pdf"), width = 20, height = 5)

Summary

# Point estimate (density mode) and HPD bounds for each prior parameter;
# `utt` and `emo` are relabelled as `prior` and `level` for the table.
# Note: this overwrites the global-parameter summary of the same name.
df_parameter_summary <- df.m %>%
  filter(parameter == "parameter", !is.na(emo)) %>%
  group_by(utt, emo, value) %>%
  summarize(
    MAP = estimate_mode(prob),
    cred_upper = hpd_upper(prob),
    cred_lower = hpd_lower(prob)
  ) %>%
  rename(prior = utt, level = emo)
## `summarise()` has grouped output by 'utt', 'emo'. You can override using the `.groups` argument.
# Print the prior parameter summary as a table.
kable(df_parameter_summary)
prior level value MAP cred_upper cred_lower
emoIsCommPrior communicative NA 0.8499244 0.9360882 0.6463909
emoIsCommPrior no_info NA 0.8279590 0.9519367 0.6332620
emoIsCommPrior noncommunicative NA 0.8414591 0.9611352 0.6874473
infGoalPrior inf mu 3.9434194 3.9981737 3.3974829
infGoalPrior inf sigma 2.3344302 2.9777432 1.8674075
infGoalPrior no_info mu 2.2421269 3.2555987 1.1814583
infGoalPrior no_info sigma 2.2753114 2.9799988 1.7112679
infGoalPrior soc mu 1.2022006 2.1809268 1.0016787
infGoalPrior soc sigma 2.0788183 2.9975383 1.7033229
socGoalPrior inf mu 2.8955958 3.7066733 2.5088735
socGoalPrior inf sigma 1.6824446 2.8136510 1.3032109
socGoalPrior no_info mu 2.8586973 3.7993390 2.6289184
socGoalPrior no_info sigma 1.5626901 2.6670844 1.2621433
socGoalPrior soc mu 3.9480441 3.9989514 3.6878253
socGoalPrior soc sigma 1.5456529 1.7311736 1.3418254
statePrior bad mu 1.0580634 1.7438720 1.0009146
statePrior bad sigma 2.9588315 2.9965795 2.5600415
statePrior good mu 4.5220204 5.8300373 2.5395105
statePrior good sigma 2.8940402 2.9947291 0.0067557
statePrior no_info mu 3.7426403 4.5412762 2.7743348
statePrior no_info sigma 2.9667587 2.9985271 2.6627207

State inference

# Posterior samples for the state inference: `value` holds the state as text,
# so it is converted to a numeric `state` column and then dropped.
df_state <- df.m %>%
  filter(parameter == "state") %>%
  mutate(
    state = as.numeric(value),
    value = NULL
  )

# Model summary per cell and state. The density mode is stored under `mean`
# and the HPD bounds under `ci_*` so the columns line up with the human
# summary; `condition` tags these rows as model output.
df_state_summary <- df_state %>%
  group_by(manipulation, manipulation_level, utt, emo, state) %>%
  summarize(
    mean = estimate_mode(prob),
    ci_upper = hpd_upper(prob),
    ci_lower = hpd_lower(prob),
    condition = "mb7"
  )
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo'. You can override using the `.groups` argument.
# Combine model with human. The two summaries do not share the same set of
# columns, so bind_rows() is called directly: it matches columns by name and
# fills the missing ones with NA, instead of relying on rbind() dispatching
# to a grouped-data-frame method.
md_state_long <- bind_rows(df_state_summary, h_state_summary)
# md_state_wide <- md_state_long %>%
#   select(-n, -empirical_n, -empirical_stat) %>%
#   reshape2::dcast(., manipulation + manipulation_level + utt + emo + state ~ condition, value.var= c("mean", "ci_upper"))

line plot

# Model vs human response curves over state: one row of panels per
# utterance/emotion pair, one column per manipulation level.
md_state_long %>%
  unite("utt_emo", utt, emo) %>%
  ggplot(aes(x = state, y = mean, group = condition, color = condition)) +
  geom_errorbar(
    aes(ymin = ci_lower, ymax = ci_upper),
    width = .1,
    position = position_dodge(0.05)
  ) +
  geom_line(aes(linetype = condition)) +
  geom_point(aes(shape = condition)) +
  scale_color_manual(values = c('#E69F00', '#999999')) +
  ylim(0, 1) +
  facet_grid(
    rows = vars(utt_emo),
    cols = vars(manipulation, manipulation_level)
  ) +
  theme(legend.position = "bottom")

scatterplots

# Same per-cell model summary as above, but with model-specific column names
# (MAP / cred_*) so they stay distinct from the human mean / ci_* columns
# after the join.
df_state_summary2 <- summarize(
  group_by(df_state, manipulation, manipulation_level, utt, emo, state),
  MAP = estimate_mode(prob),
  cred_upper = hpd_upper(prob),
  cred_lower = hpd_lower(prob)
)
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'utt', 'emo'. You can override using the `.groups` argument.
# Attach the human summary to every model cell (joined on all shared columns).
md_state_wide <- df_state_summary2 %>%
  left_join(h_state_summary)
## Joining, by = c("manipulation", "manipulation_level", "utt", "emo", "state")
# Model-human fit per manipulation level: mean squared error and Pearson
# correlation between the model MAP and the human proportion.
# The grouping is stated explicitly rather than inherited from whatever
# groups the upstream summarize()/join left on the data, so the table keeps
# one row per manipulation x level even if upstream grouping changes.
md_state_corr_table <- md_state_wide %>%
  group_by(manipulation, manipulation_level) %>%
  # model cells with no matching human row count as a human proportion of 0
  mutate(mean = coalesce(mean, 0)) %>%
  summarize(
    mse = mean((MAP - mean)^2),
    r = cor(MAP, mean),
    r2 = r^2
  )
## `summarise()` has grouped output by 'manipulation'. You can override using the `.groups` argument.
#write_csv(md_state_corr_table, "../state_correlations.csv")

# Scatter of model prediction (x, with HPD range) against human proportion
# (y, with bootstrap CI range), one panel per manipulation level, with the
# panel's correlation printed in the corner.
# TRUE/FALSE are spelled out (T and F are ordinary, reassignable variables)
# and paste0() replaces paste(sep = "").
md_state_wide %>%
  unite("utt_emo", utt, emo) %>%
  mutate(
    # model cells with no matching human row are drawn at 0
    mean = coalesce(mean, 0),
    ci_lower = coalesce(ci_lower, 0),
    ci_upper = coalesce(ci_upper, 0),
    state = factor(state)
  ) %>%
  ggplot(aes(
    x = MAP, xmin = cred_lower, xmax = cred_upper,
    y = mean, ymin = ci_lower, ymax = ci_upper,
    shape = utt_emo, color = state
  )) +
  geom_abline(intercept = 0, slope = 1, alpha = 0.3, linetype = 2) +
  geom_linerange() +
  geom_text(
    data = md_state_corr_table, x = 0.15, y = 0.93,
    aes(label = paste0("r=", round(r, 2))),
    inherit.aes = FALSE
  ) +
  ggstance::geom_linerangeh() +
  geom_point() +
  scale_color_viridis(discrete = TRUE) +
  coord_fixed() +
  facet_wrap(vars(manipulation, manipulation_level), ncol = 3) +
  scale_y_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  scale_x_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  theme(legend.position = 'right') +
  labs(
    x = "Model Predicted Probability",
    y = "Human Proportion Selected"
  )

# ggsave(filename = "bda_results/figs/bda_scatters_state_21models_cogsci.pdf", width = 18, height = 5)

correlation table

# Print the state fit statistics as a table.
kable(md_state_corr_table)
manipulation manipulation_level mse r r2
emoIsComm_manipulation comm 0.0096288 0.8160185 0.6658861
emoIsComm_manipulation no_info 0.0120118 0.8083581 0.6534428
emoIsComm_manipulation non_comm 0.0257575 0.6489064 0.4210795
goal_manipulation inf 0.0085907 0.8750569 0.7657245
goal_manipulation no_info 0.0083104 0.8875054 0.7876659
goal_manipulation soc 0.0120027 0.8369155 0.7004275
state_manipulation bad 0.0142679 0.8000517 0.6400828
state_manipulation good 0.0114836 0.7712907 0.5948893
state_manipulation no_info 0.0130273 0.7678206 0.5895485

infGoal inference

# Posterior samples for both goal inferences: `value` holds the rating as
# text, so it is converted to a numeric `rating` column and then dropped.
df_goal <- df.m %>%
  filter(parameter %in% c("socGoal", "infGoal")) %>%
  mutate(
    rating = as.numeric(value),
    value = NULL
  )

# Model summary per cell, goal type and rating. The density mode is stored
# under `mean` and the HPD bounds under `ci_*` so the columns line up with the
# human summary; `condition` tags these rows as model output.
df_goal_summary <- df_goal %>%
  group_by(manipulation, manipulation_level, parameter, utt, emo, rating) %>%
  summarize(
    mean = estimate_mode(prob),
    ci_upper = hpd_upper(prob),
    ci_lower = hpd_lower(prob),
    condition = "mb7"
  )
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'parameter', 'utt', 'emo'. You can override using the `.groups` argument.
# Split the model goal summary by goal type.
df_inf_summary <- filter(df_goal_summary, parameter == "infGoal")
df_soc_summary <- filter(df_goal_summary, parameter == "socGoal")

# Combine model with human. The model rows carry `parameter` while the human
# rows carry `question`, so bind_rows() is called directly: it matches columns
# by name and fills the missing ones with NA, instead of relying on rbind()
# dispatching to a grouped-data-frame method.
md_inf_long <- bind_rows(df_inf_summary, h_inf_summary)

line plot

# Model vs human response curves over rating for the informational goal: one
# row of panels per utterance/emotion pair, one column per manipulation level.
md_inf_long %>%
  unite("utt_emo", utt, emo) %>%
  ggplot(aes(x = rating, y = mean, group = condition, color = condition)) +
  geom_errorbar(
    aes(ymin = ci_lower, ymax = ci_upper),
    width = .1,
    position = position_dodge(0.05)
  ) +
  geom_line(aes(linetype = condition)) +
  geom_point(aes(shape = condition)) +
  scale_color_manual(values = c('#E69F00', '#999999')) +
  ylim(0, 1) +
  facet_grid(
    rows = vars(utt_emo),
    cols = vars(manipulation, manipulation_level)
  ) +
  theme(legend.position = "bottom")

scatterplots

# Per-cell model summary with model-specific column names (MAP / cred_*),
# plus a `question` column that maps the model's goal names onto the labels
# used in the human data so the two can be joined.
df_goal_summary2 <- df_goal %>%
  group_by(manipulation, manipulation_level, parameter, utt, emo, rating) %>%
  summarize(
    MAP = estimate_mode(prob),
    cred_upper = hpd_upper(prob),
    cred_lower = hpd_lower(prob)
  ) %>%
  mutate(
    question = factor(
      parameter,
      levels = c("infGoal", "socGoal"),
      labels = c("informational goal", "social goal")
    )
  )
## `summarise()` has grouped output by 'manipulation', 'manipulation_level', 'parameter', 'utt', 'emo'. You can override using the `.groups` argument.
# Attach the human summary to every model cell (joined on all shared columns).
md_goals_wide <- df_goal_summary2 %>%
  left_join(h_goal_summary)
## Joining, by = c("manipulation", "manipulation_level", "utt", "emo", "rating", "question")
# Model-human fit per manipulation level and question: number of cells, mean
# squared error and Pearson correlation between model MAP and human
# proportion.
md_goal_corr_table <- md_goals_wide %>%
  unite("utt_emo", utt, emo) %>%
  mutate(
    # model cells with no matching human row count as 0
    mean = coalesce(mean, 0),
    ci_lower = coalesce(ci_lower, 0),
    ci_upper = coalesce(ci_upper, 0)
  ) %>%
  group_by(manipulation, manipulation_level, question) %>%
  summarize(
    n = n(),
    mse = mean((MAP - mean)^2),
    r = cor(MAP, mean),
    r2 = r^2
  )
## `summarise()` has grouped output by 'manipulation', 'manipulation_level'. You can override using the `.groups` argument.
# write_csv(md_goal_corr_table, "../goal_correlations.csv")

# Scatter of model prediction (x, with HPD range) against human proportion
# (y, with bootstrap CI range) for the informational goal, one panel per
# manipulation level, with the panel's correlation printed in the corner.
md_goals_wide %>%
  filter(parameter == "infGoal") %>%
  unite("utt_emo", utt, emo) %>%
  mutate(
    # model cells with no matching human row are drawn at 0
    mean = coalesce(mean, 0),
    ci_lower = coalesce(ci_lower, 0),
    ci_upper = coalesce(ci_upper, 0),
    rating = factor(rating)
  ) %>%
  ggplot(aes(
    x = MAP, xmin = cred_lower, xmax = cred_upper,
    y = mean, ymin = ci_lower, ymax = ci_upper,
    shape = utt_emo, color = rating
  )) +
  geom_abline(intercept = 0, slope = 1, alpha = 0.3, linetype = 2) +
  geom_linerange() +
  geom_text(
    data = md_goal_corr_table[md_goal_corr_table$question == "informational goal", ],
    x = 0.15, y = 0.96,
    aes(label = paste0("r=", round(r, 2))),
    inherit.aes = FALSE
  ) +
  ggstance::geom_linerangeh() +
  geom_point() +
  coord_fixed() +
  # The row count is an argument of facet_wrap() (`nrow`), not of vars():
  # inside vars() it was treated as an extra, constant faceting variable.
  facet_wrap(vars(manipulation, manipulation_level), nrow = 3) +
  scale_y_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  scale_x_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  theme(legend.position = 'right') +
  labs(
    x = "Model Predicted Probability",
    y = "Human Proportion Selected"
  )

#ggsave(filename = "bda_results/figs/bda_scatters_goal_21models_cogsci.pdf", width = 24, height = 4.5)

correlation table

# Print the fit statistics for the informational goal as a table.
kable(
  md_goal_corr_table[md_goal_corr_table$question == "informational goal", ]
)
manipulation manipulation_level question n mse r r2
emoIsComm_manipulation comm informational goal 16 0.0116011 0.7743644 0.5996402
emoIsComm_manipulation no_info informational goal 16 0.0195935 0.7640446 0.5837642
emoIsComm_manipulation non_comm informational goal 16 0.0244856 0.7458098 0.5562322
goal_manipulation inf informational goal 16 0.0252854 0.7800752 0.6085173
goal_manipulation no_info informational goal 16 0.0277944 0.6747127 0.4552373
goal_manipulation soc informational goal 16 0.0176360 0.7330459 0.5373563
state_manipulation bad informational goal 16 0.0230603 0.7758643 0.6019655
state_manipulation good informational goal 16 0.0159284 0.6539092 0.4275973
state_manipulation no_info informational goal 16 0.0303979 0.6704201 0.4494631

socGoal inference

# Combine model with human. The model rows carry `parameter` while the human
# rows carry `question`, so bind_rows() is called directly: it matches columns
# by name and fills the missing ones with NA, instead of relying on rbind()
# dispatching to a grouped-data-frame method.
md_soc_long <- bind_rows(df_soc_summary, h_soc_summary)

line plot

# Model vs human response curves over rating for the social goal: one row of
# panels per utterance/emotion pair, one column per manipulation level.
md_soc_long %>%
  unite("utt_emo", utt, emo) %>%
  ggplot(aes(x = rating, y = mean, group = condition, color = condition)) +
  geom_errorbar(
    aes(ymin = ci_lower, ymax = ci_upper),
    width = .1,
    position = position_dodge(0.05)
  ) +
  geom_line(aes(linetype = condition)) +
  geom_point(aes(shape = condition)) +
  scale_color_manual(values = c('#E69F00', '#999999')) +
  ylim(0, 1) +
  facet_grid(
    rows = vars(utt_emo),
    cols = vars(manipulation, manipulation_level)
  ) +
  theme(legend.position = "bottom")

scatterplots

# Scatter of model prediction (x, with HPD range) against human proportion
# (y, with bootstrap CI range) for the social goal, one panel per
# manipulation level, with the panel's correlation printed in the corner.
md_goals_wide %>%
  filter(parameter == "socGoal") %>%
  unite("utt_emo", utt, emo) %>%
  mutate(
    # model cells with no matching human row are drawn at 0
    mean = coalesce(mean, 0),
    ci_lower = coalesce(ci_lower, 0),
    ci_upper = coalesce(ci_upper, 0),
    rating = factor(rating)
  ) %>%
  ggplot(aes(
    x = MAP, xmin = cred_lower, xmax = cred_upper,
    y = mean, ymin = ci_lower, ymax = ci_upper,
    shape = utt_emo, color = rating
  )) +
  geom_abline(intercept = 0, slope = 1, alpha = 0.3, linetype = 2) +
  geom_linerange() +
  geom_text(
    data = md_goal_corr_table[md_goal_corr_table$question == "social goal", ],
    x = 0.15, y = 0.96,
    aes(label = paste0("r=", round(r, 2))),
    inherit.aes = FALSE
  ) +
  ggstance::geom_linerangeh() +
  geom_point() +
  coord_fixed() +
  # The row count is an argument of facet_wrap() (`nrow`), not of vars():
  # inside vars() it was treated as an extra, constant faceting variable.
  facet_wrap(vars(manipulation, manipulation_level), nrow = 3) +
  scale_y_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  scale_x_continuous(limits = c(0, 1), breaks = c(0, 1)) +
  theme(legend.position = 'right') +
  labs(
    x = "Model Predicted Probability",
    y = "Human Proportion Selected"
  )

correlation table

# Print the fit statistics for the social goal as a table.
kable(
  md_goal_corr_table[md_goal_corr_table$question == "social goal", ]
)
manipulation manipulation_level question n mse r r2
emoIsComm_manipulation comm social goal 16 0.0214459 0.4489083 0.2015187
emoIsComm_manipulation no_info social goal 16 0.0160695 0.5946875 0.3536533
emoIsComm_manipulation non_comm social goal 16 0.0121807 0.6924029 0.4794217
goal_manipulation inf social goal 16 0.0102452 0.7584241 0.5752072
goal_manipulation no_info social goal 16 0.0136562 0.7448383 0.5547841
goal_manipulation soc social goal 16 0.0142403 0.7934307 0.6295323
state_manipulation bad social goal 16 0.0129808 0.7009897 0.4913865
state_manipulation good social goal 16 0.0134526 0.7286970 0.5309993
state_manipulation no_info social goal 16 0.0158089 0.6940490 0.4817040